#!/usr/bin/env python3
# D9 (REAL v4) — E-Proxy + Gauss-Plateau
# Control: DDA 1/r scheduler (integer, deterministic), rails OFF, no curve weights, no RNG.
# Readouts:
#   - SLOPE: per-cell rate vs (R_eff / r) on LOG-spaced annuli (mid-window linear fit in log-log space).
#   - PLATEAU: total counts per annulus per tick on FULL linear bins (equal Δshell, trailing partial bin dropped),
#              CV evaluated over outer_frac of those full bins.

import argparse, csv, hashlib, json, math
from pathlib import Path

# ---------- utils ----------
def ensure_dir(p: Path):
    """Create directory ``p`` together with any missing parents.

    An already-existing directory is not an error.
    """
    p.mkdir(parents=True, exist_ok=True)
def sha256_of_file(p: Path):
    """Return the hexadecimal SHA-256 digest of the file at ``p``.

    The file is consumed in 1 MiB pieces so it is never loaded whole.
    """
    digest = hashlib.sha256()
    piece_size = 1 << 20
    with p.open('rb') as stream:
        while True:
            piece = stream.read(piece_size)
            if piece == b'':
                break
            digest.update(piece)
    return digest.hexdigest()
def write_json(p: Path, obj):
    """Write ``obj`` to ``p`` as UTF-8 JSON with 2-space indentation.

    The parent directory is created first. ``json.dumps`` is used with its
    defaults, so non-finite floats come out as ``NaN`` / ``Infinity`` tokens.
    """
    ensure_dir(p.parent)
    payload = json.dumps(obj, indent=2)
    p.write_text(payload, encoding='utf-8')
def write_csv(p: Path, header, rows):
    """Write ``header`` followed by every row of ``rows`` to ``p`` as UTF-8 CSV.

    The parent directory is created first. The file is opened with
    ``newline=''`` so the csv writer alone decides the line terminator.
    """
    ensure_dir(p.parent)
    with p.open('w', newline='', encoding='utf-8') as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)

# ---------- shells & control (DDA 1/r) ----------
def build_shells(nx, ny, outer_margin):
    """Count grid cells per integer radial shell.

    The effective radius is ``min(nx, ny) / 2 - outer_margin``; when that is
    not greater than 1 it falls back to ``max(nx, ny) / 4``. Each cell gets a
    radius from the offsets ``x + 0.5 - cx`` and ``y + 0.5 - cy`` (with
    ``cx = (nx - 1) / 2``, ``cy = (ny - 1) / 2``) and is assigned to shell
    ``max(1, int(radius))``. Cells whose shell exceeds ``S_max`` are ignored.

    Returns:
        (cells_total, S_max, R_eff) where ``cells_total[s]`` is the number of
        cells in shell ``s`` (index 0 is never incremented), ``S_max`` is
        ``max(1, int(R_eff))`` and ``R_eff`` is the effective radius.
    """
    center_x = (nx - 1) * 0.5
    center_y = (ny - 1) * 0.5

    R_eff = min(nx, ny) * 0.5 - outer_margin
    if R_eff <= 1:
        R_eff = max(nx, ny) * 0.25
    S_max = max(1, int(R_eff))

    cells_total = [0] * (S_max + 1)
    for row in range(ny):
        # The squared vertical offset is constant along a row.
        dy_sq = (row + 0.5 - center_y) ** 2
        for col in range(nx):
            radius = ((col + 0.5 - center_x) ** 2 + dy_sq) ** 0.5
            shell = max(1, int(radius))
            if shell > S_max:
                continue
            cells_total[shell] += 1
    return cells_total, S_max, R_eff

def run_engine(nx, ny, H, outer_margin, rate_num):
    """Run the per-shell accumulator scheduler for ``H`` ticks.

    Every tick, each shell ``s`` (1..S_max) adds ``rate_num`` to its
    accumulator. When the accumulator reaches ``s`` it is reduced by ``s``
    and all cells of that shell are counted as committed. A shell therefore
    fires at most once per tick.

    Returns:
        dict with ``S_max``, ``R_eff``, ``cells_total``, ``per_cell_rate``
        (commits per cell per tick, by shell) and ``per_shell_total_per_tick``
        (commits per tick, by shell). Shells without cells keep 0.0.
    """
    cells_total, S_max, R_eff = build_shells(nx, ny, outer_margin)
    shell_ids = range(1, S_max + 1)

    credit = [0] * (S_max + 1)
    fired_cells = [0] * (S_max + 1)
    for _tick in range(H):
        for shell in shell_ids:
            balance = credit[shell] + rate_num
            if balance >= shell:
                balance -= shell
                # rails OFF: the whole shell fires at once
                fired_cells[shell] += cells_total[shell]
            credit[shell] = balance

    per_cell_rate = [0.0] * (S_max + 1)
    per_shell_total_per_tick = [0.0] * (S_max + 1)
    for shell in shell_ids:
        n_cells = cells_total[shell]
        if n_cells <= 0:
            continue
        per_cell_rate[shell] = fired_cells[shell] / (n_cells * H)
        per_shell_total_per_tick[shell] = fired_cells[shell] / H

    return {
        "S_max": S_max,
        "R_eff": R_eff,
        "cells_total": cells_total,
        "per_cell_rate": per_cell_rate,
        "per_shell_total_per_tick": per_shell_total_per_tick
    }

# ---------- annuli: LOG edges for slope, FULL linear bins for plateau ----------
def make_log_edges(R_eff, n_log):
    """Return ``n_log + 1`` geometrically spaced radii from r_min to r_max.

    ``r_min`` is ``max(1, R_eff / 16)`` and ``r_max`` is
    ``max(r_min + 1, R_eff)``, so the span is always at least one unit wide.
    """
    r_min = max(1.0, R_eff / 16.0)
    r_max = max(r_min + 1.0, R_eff)
    ratio = r_max / r_min
    steps = float(n_log)
    edges = []
    for k in range(n_log + 1):
        edges.append(r_min * ratio ** (k / steps))
    return edges

def aggregate_for_slope(per_cell_rate, cells_total, log_edges, R_eff):
    """Aggregate per-shell rates into the annuli delimited by ``log_edges``.

    For each consecutive edge pair (r_lo, r_hi) the shells
    ``max(1, int(r_lo)) .. int(r_hi)`` (inclusive) are pooled and their
    per-cell rates averaged, weighted by cell count. Because the upper shell
    is inclusive, the shell at a shared edge contributes to both neighbouring
    annuli.

    The upper shell index is clamped to the last index available in
    ``cells_total`` / ``per_cell_rate``. Without the clamp, an outer edge
    beyond the shell table (possible when R_eff < 2, since the edge range is
    forced to be at least one unit wide) raised IndexError.

    Args:
        per_cell_rate: per-shell commit rate per cell, indexed by shell.
        cells_total: per-shell cell counts, indexed by shell.
        log_edges: ascending annulus edges (n + 1 values for n annuli).
        R_eff: effective radius used to normalise the abscissa.

    Returns:
        (inv_r, rate_ann, r_center): ``R_eff / r_center``, the cell-weighted
        mean rate (0.0 for an annulus with no cells), and the arithmetic
        midpoint of each annulus.
    """
    last_shell = min(len(cells_total), len(per_cell_rate)) - 1
    r_center = []
    inv_r = []
    rate_ann = []
    for r_lo, r_hi in zip(log_edges, log_edges[1:]):
        s_lo = max(1, int(r_lo))
        # Clamp so an edge past the shell table cannot index out of range.
        s_hi = min(max(s_lo, int(r_hi)), last_shell)
        tot_cells = 0
        acc_rate = 0.0
        for s in range(s_lo, s_hi + 1):
            c = cells_total[s]
            acc_rate += per_cell_rate[s] * c
            tot_cells += c
        mid = 0.5 * (r_lo + r_hi)
        r_center.append(mid)
        inv_r.append(R_eff / max(1e-9, mid))
        rate_ann.append((acc_rate / tot_cells) if tot_cells > 0 else 0.0)
    return inv_r, rate_ann, r_center

def make_full_linear_bins(S_max, shells_per_bin):
    """Return list of (s_lo, s_hi) where every bin has exactly shells_per_bin shells.

    Bins start at shell 1 and are contiguous. The trailing partial bin is
    dropped to avoid CV inflation.

    Raises:
        ValueError: if ``shells_per_bin`` is smaller than 1. A negative value
            previously made the loop run forever, and zero surfaced as an
            unexplained ZeroDivisionError.
    """
    if shells_per_bin < 1:
        raise ValueError(f"shells_per_bin must be >= 1, got {shells_per_bin}")
    # Highest shell index covered by complete bins.
    full = (S_max // shells_per_bin) * shells_per_bin
    return [(s, s + shells_per_bin - 1) for s in range(1, full + 1, shells_per_bin)]

def aggregate_for_plateau(per_shell_total, lin_bins):
    """Sum the per-shell totals inside each inclusive (s_lo, s_hi) bin.

    Returns one total per entry of ``lin_bins``, in the same order.
    """
    return [
        sum(per_shell_total[shell] for shell in range(first, last + 1))
        for first, last in lin_bins
    ]

# ---------- stats ----------
def fit_loglog_slope(x, y, mid_frac=0.60):
    """Least-squares slope of log(y) against log(x) over the middle of the data.

    Roughly ``(1 - mid_frac) / 2`` of the points are trimmed from each end
    and the remaining window ``[lo_idx, hi_idx)`` is fitted. Values are
    floored at 1e-12 before taking the logarithm so zeros do not raise.

    The window is clamped to the data length. Previously it was forced to be
    at least two points wide without regard to ``len(x)``, so inputs with
    fewer than two points raised IndexError and the NaN fallback could never
    be reached.

    Args:
        x, y: equally long sequences of numbers.
        mid_frac: fraction of the points kept in the centre window.

    Returns:
        dict with ``slope`` (NaN when the window has fewer than two points or
        x has no spread), ``r2``, ``lo_idx`` and ``hi_idx``.
    """
    n = len(x)
    lo = max(0, int((1.0 - mid_frac) * 0.5 * n))
    # Ask for at least two points, but never index past the end of the data.
    hi = min(n, max(lo + 2, int(n - lo)))
    X = [math.log(max(1e-12, x[i])) for i in range(lo, hi)]
    Y = [math.log(max(1e-12, y[i])) for i in range(lo, hi)]
    m = len(X)
    if m < 2:
        return {"slope": float('nan'), "r2": 0.0, "lo_idx": lo, "hi_idx": hi}
    mx = sum(X) / m
    my = sum(Y) / m
    sx = sum((u - mx) ** 2 for u in X)
    sxy = sum((u - mx) * (v - my) for u, v in zip(X, Y))
    slope = sxy / sx if sx > 0 else float('nan')
    ss_tot = sum((v - my) ** 2 for v in Y)
    ss_res = sum((v - (my + slope * (u - mx))) ** 2 for u, v in zip(X, Y))
    r2 = 1.0 - (ss_res / ss_tot if ss_tot > 0 else 0.0)
    return {"slope": slope, "r2": r2, "lo_idx": lo, "hi_idx": hi}

def plateau_cv_ok(totals, outer_frac, cv_max):
    """Check the coefficient of variation of the outermost bins against ``cv_max``.

    The last ``k = max(1, round(len(totals) * outer_frac))`` entries are
    used. CV is the population standard deviation divided by the mean, or
    infinity when the mean is not positive.

    Returns:
        (ok, meta) where ``ok`` is ``cv <= cv_max`` and ``meta`` holds ``cv``
        and ``k``. Empty input gives ``(False, {"cv": inf, "k": 0})``.
    """
    count = len(totals)
    if not count:
        return False, {"cv": float('inf'), "k": 0}

    k = max(1, int(round(count * outer_frac)))
    tail = totals[-k:]
    size = len(tail)
    avg = sum(tail) / size
    variance = sum((value - avg) ** 2 for value in tail) / size
    spread = variance ** 0.5
    if avg > 0:
        cv = spread / avg
    else:
        cv = float('inf')
    return (cv <= cv_max), {"cv": cv, "k": k}

# ---------- one panel run ----------
def run_one(manifest_path: Path, diag_path: Path, out_dir: Path):
    """Run one panel: engine, slope fit, plateau check, then write all outputs.

    Reads the grid size and tick count from the manifest JSON and the ring,
    control, annuli and tolerance settings from the diag JSON. Writes two CSV
    panels and the slope fit under ``metrics/``, the pass/fail audit under
    ``audits/`` and the input file hashes under ``run_info/``.
    """
    metrics_dir = out_dir / 'metrics'
    audits_dir = out_dir / 'audits'
    run_info_dir = out_dir / 'run_info'
    for folder in (metrics_dir, audits_dir, run_info_dir):
        ensure_dir(folder)

    manifest = json.loads(manifest_path.read_text('utf-8'))
    diag = json.loads(diag_path.read_text('utf-8'))

    # --- manifest: domain ---
    grid = manifest["domain"]["grid"]
    nx = int(grid["nx"])
    ny = int(grid["ny"])
    H = int(manifest["domain"]["ticks"])

    # --- diag: controls, binning, tolerances ---
    outer_margin = int(diag["ring"]["outer_margin"])
    rate_num = int(diag["controls"]["rate_num"])
    annuli = diag["annuli"]
    n_log = int(annuli["n_log"])
    shells_per_bin = int(annuli["lin_shells_per_bin"])
    tolerances = diag["tolerances"]
    slope_abs_tol = float(tolerances["slope_abs_tol"])
    r2_min = float(tolerances["r2_min"])
    outer_frac = float(tolerances["plateau_outer_frac"])
    cv_max = float(tolerances["plateau_cv_max"])

    res = run_engine(nx, ny, H, outer_margin, rate_num)

    # SLOPE (log bins): the fitted slope must be within tolerance of 1
    log_edges = make_log_edges(res["R_eff"], n_log)
    inv_r, rate_ann, r_center = aggregate_for_slope(
        res["per_cell_rate"], res["cells_total"], log_edges, res["R_eff"])
    fit = fit_loglog_slope(inv_r, rate_ann, mid_frac=0.60)
    slope_within_tol = abs(fit["slope"] - 1.0) <= slope_abs_tol
    slope_ok = slope_within_tol and (fit["r2"] >= r2_min)

    # PLATEAU (linear full bins)
    lin_bins = make_full_linear_bins(res["S_max"], shells_per_bin)
    totals = aggregate_for_plateau(res["per_shell_total_per_tick"], lin_bins)
    plat_ok, plat_meta = plateau_cv_ok(totals, outer_frac, cv_max)

    # write metrics
    slope_rows = [
        [idx, center, inv, rate]
        for idx, (center, inv, rate) in enumerate(zip(r_center, inv_r, rate_ann))
    ]
    write_csv(metrics_dir / 'slope_panel_log.csv',
              ['idx', 'r_center', 'inv_r', 'per_cell_rate'],
              slope_rows)

    plateau_rows = [
        [idx, shell_lo, shell_hi, total]
        for idx, ((shell_lo, shell_hi), total) in enumerate(zip(lin_bins, totals))
    ]
    write_csv(metrics_dir / 'plateau_panel_linear_fullbins.csv',
              ['bin_id', 'shell_lo', 'shell_hi', 'total_per_tick'],
              plateau_rows)

    write_json(metrics_dir / 'slope_fit.json', fit)

    audit = {
        "slope_ok": slope_ok,
        "slope": fit["slope"],
        "r2": fit["r2"],
        "plateau_ok": plat_ok,
        "plateau_cv": plat_meta["cv"],
        "plateau_outer_frac": outer_frac,
        "plateau_cv_max": cv_max,
        "PASS": bool(slope_ok and plat_ok),
    }
    write_json(audits_dir / 'em_eproxy.json', audit)

    # provenance
    provenance = {
        "manifest_hash": sha256_of_file(manifest_path),
        "diag_hash": sha256_of_file(diag_path),
        "engine_entrypoint": f"python {manifest_path.name} -> engine_d9_eproxy_v4.py",
    }
    write_json(run_info_dir / 'hashes.json', provenance)

# ---------- CLI ----------
if __name__ == '__main__':
    # Command-line entry: three required paths, then a single panel run.
    cli = argparse.ArgumentParser()
    for flag in ('--manifest', '--diag', '--out'):
        cli.add_argument(flag, required=True)
    opts = cli.parse_args()

    out_root = Path(opts.out)
    try:
        run_one(Path(opts.manifest), Path(opts.diag), out_root)
        print("D9 v4 DONE")
    except Exception as e:
        # Leave a failing audit record behind, then let the error propagate.
        audits_dir = out_root / 'audits'
        ensure_dir(audits_dir)
        failure = {"PASS": False, "failure_reason": f"{type(e).__name__}: {e}"}
        write_json(audits_dir / 'em_eproxy.json', failure)
        raise
